import torch
import torch.nn as nn

from collections import OrderedDict
import numpy as np
import json



class Summary:
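    """Keras-style layer-by-layer summary for a PyTorch ``nn.Module``.

    Runs one forward pass with dummy inputs, records each leaf module's
    output shape and parameter count via forward hooks, and exposes the
    result through ``print()`` and ``save()``.
    """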

    def __init__(self, model, input_size, batch_size=-1, device="cuda"):
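        """
        Args:
            model: the ``nn.Module`` to summarize; it is assumed to already
                live on ``device``.
            input_size: a single input shape without the batch dimension,
                e.g. ``(3, 224, 224)``, or a list of such tuples for
                multi-input models.
            batch_size: batch size reported in the summary (-1 means "any").
            device: "cuda" or "cpu"; controls the dtype of the dummy inputs.
        """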
        self.model = model
        self.batch_size = batch_size

        device = device.lower()
        assert device in [
            "cuda",
            "cpu",
        ], "Input device is not valid, please specify 'cuda' or 'cpu'"

        if device == "cuda" and torch.cuda.is_available():
            dtype = torch.cuda.FloatTensor
        else:
            dtype = torch.FloatTensor
        # normalize to a list of input shapes so multi-input models work too
        if isinstance(input_size, tuple):
            input_size = [input_size]
        self.input_size = input_size
        # use a batch size of 2 so batchnorm layers do not fail
        self.input = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
        self.summary = self.forward()


    def forward(self):
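        """Run the dummy inputs through the model and collect per-layer stats.

        Returns an ``OrderedDict`` mapping ``"<LayerType>-<index>"`` to that
        layer's input/output shape, parameter count, and trainability.
        """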

        def register_hook(module):
            def hook(module, input, output):
                class_name = str(module.__class__).split(".")[-1].split("'")[0]
                module_idx = len(summary)

                m_key = "%s-%i" % (class_name, module_idx + 1)
                summary[m_key] = OrderedDict()
                summary[m_key]["input_shape"] = list(input[0].size())
                summary[m_key]["input_shape"][0] = self.batch_size
                if isinstance(output, (list, tuple)):
                    summary[m_key]["output_shape"] = [
                        [-1] + list(o.size())[1:] for o in output
                    ]
                else:
                    summary[m_key]["output_shape"] = list(output.size())
                    summary[m_key]["output_shape"][0] = self.batch_size

                params = 0
                if hasattr(module, "weight") and hasattr(module.weight, "size"):
                    params += torch.prod(torch.LongTensor(list(module.weight.size())))
                    summary[m_key]["trainable"] = module.weight.requires_grad
                if hasattr(module, "bias") and hasattr(module.bias, "size"):
                    params += torch.prod(torch.LongTensor(list(module.bias.size())))
                if isinstance(params, torch.Tensor):
                    params = params.item()
                summary[m_key]["nb_params"] = params

            if (
                not isinstance(module, nn.Sequential)
                and not isinstance(module, nn.ModuleList)
                and not (module == self.model)
            ):
                hooks.append(module.register_forward_hook(hook))
        
        # create properties
        summary = OrderedDict()
        hooks = []

        # register hook
        self.model.apply(register_hook)

        # make a forward pass
        self.model(*self.input)

        # remove these hooks
        for h in hooks:
            h.remove()
        
        return summary


    def save(self, path='summary.json'):
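        """Write the collected summary to ``path`` as JSON."""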
        with open(path, 'w') as f:
            json.dump(self.summary, f)


    def print(self):
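        """Print a Keras-style table of layers, output shapes, and parameter
        counts, followed by totals and a rough memory estimate
        (assuming float32, i.e. 4 bytes per value).
        """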
        print("----------------------------------------------------------------")
        line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
        print(line_new)
        print("================================================================")
        total_params = 0
        total_output = 0
        trainable_params = 0
        for layer in self.summary:
            # columns: layer name (type-index), output shape, parameter count
            line_new = "{:>20}  {:>25} {:>15}".format(
                layer,
                str(self.summary[layer]["output_shape"]),
                "{0:,}".format(self.summary[layer]["nb_params"]),
            )
            total_params += self.summary[layer]["nb_params"]
            total_output += np.prod(self.summary[layer]["output_shape"])
            if self.summary[layer].get("trainable", False):
                trainable_params += self.summary[layer]["nb_params"]
            print(line_new)

        # assume 4 bytes/number (float32)
        total_input_size = abs(np.prod(sum(self.input_size, ())) * self.batch_size * 4. / (1024 ** 2.))
        total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
        total_params_size = abs(total_params * 4. / (1024 ** 2.))
        total_size = total_params_size + total_output_size + total_input_size

        print("================================================================")
        print("Total params: {0:,}".format(total_params))
        print("Trainable params: {0:,}".format(trainable_params))
        print("Non-trainable params: {0:,}".format(total_params - trainable_params))
        print("----------------------------------------------------------------")
        print("Input size (MB): %0.2f" % total_input_size)
        print("Forward/backward pass size (MB): %0.2f" % total_output_size)
        print("Params size (MB): %0.2f" % total_params_size)
        print("Estimated Total Size (MB): %0.2f" % total_size)
        print("----------------------------------------------------------------")


if __name__ == "__main__":
    from torchvision.models import resnet152
    import time

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = resnet152(pretrained=True).to(device)

    input_size = (3, 64, 64)
    t = time.time()
    print('==========time(start):', 0)
    Summary(model, input_size, device=device).save()
    print('==========time(summary): ', time.time() - t)
    with open('model.log', 'w') as f:
        f.write(str(model))
    print('==========time(model): ', time.time() - t)
